#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Feb  9 15:17:48 2017

@author: kevin
"""

import tensorflow as tf
import numpy as np

class vgg16:
     def __init__(self, imgs, weights=None, sess=None):
          self.imgs = imgs
          self.convlayers()
          self.fc_layers()
          self.probs = tf.squeeze(tf.nn.softmax(self.fc3l))
          if weights is not None and sess is not None:
               self.load_weights(weights, sess)
               
     def convlayers(self):
          self.parameters = []

          # zero-mean input
          with tf.name_scope('preprocess') as scope:
               mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
               images = self.imgs - mean
               
          # conv1_1
          with tf.name_scope('conv1_1') as scope:
               kernel = tf.Variable(tf.truncated_normal([3,3,3,64], dtype=tf.float32, stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(images, kernel, [1,1,1,1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32), trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv1_1 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # conv1_2
          with tf.name_scope('conv1_2') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 64], dtype=tf.float32,stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.conv1_1, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv1_2 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # pool1
          self.pool1 = tf.nn.max_pool(self.conv1_2,
                                      ksize=[1,2,2,1],
                                      strides=[1,2,2,1],               
                                      padding='SAME',
                                      name='pool1')
          # conv2_1
          with tf.name_scope('conv2_1') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.pool1, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
                                 trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv2_1 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # conv2_2
          with tf.name_scope('conv2_2') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 128, 128], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.conv2_1, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
                                 trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv2_2 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # pool2
          self.pool2 = tf.nn.max_pool(self.conv2_2,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME',
                               name='pool2')
          
          # conv3_1
          with tf.name_scope('conv3_1') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 128, 256], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.pool2, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
                                 trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv3_1 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # conv3_2
          with tf.name_scope('conv3_2') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.conv3_1, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
                                 trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv3_2 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # conv3_3
          with tf.name_scope('conv3_3') as scope:
               kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
               conv = tf.nn.conv2d(self.conv3_2, kernel, [1, 1, 1, 1], padding='SAME')
               biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
                                 trainable=True, name='biases')
               out = tf.nn.bias_add(conv, biases)
               self.conv3_3 = tf.nn.relu(out, name=scope)
               self.parameters += [kernel, biases]

          # pool3
          self.pool3 = tf.nn.max_pool(self.conv3_3,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME',
                               name='pool3')
          
          # conv4_1
          with tf.name_scope('conv4_1') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.pool3, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv4_1 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # conv4_2
          with tf.name_scope('conv4_2') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.conv4_1, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv4_2 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # conv4_3
          with tf.name_scope('conv4_3') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.conv4_2, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv4_3 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # pool4
          self.pool4 = tf.nn.max_pool(self.conv4_3,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME',
                               name='pool4')
          
          # conv5_1
          with tf.name_scope('conv5_1') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.pool4, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv5_1 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # conv5_2
          with tf.name_scope('conv5_2') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.conv5_1, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv5_2 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # conv5_3
          with tf.name_scope('conv5_3') as scope:
            kernel = tf.Variable(tf.truncated_normal([3, 3, 512, 512], dtype=tf.float32,
                                                     stddev=1e-1), name='weights')
            conv = tf.nn.conv2d(self.conv5_2, kernel, [1, 1, 1, 1], padding='SAME')
            biases = tf.Variable(tf.constant(0.0, shape=[512], dtype=tf.float32),
                                 trainable=True, name='biases')
            out = tf.nn.bias_add(conv, biases)
            self.conv5_3 = tf.nn.relu(out, name=scope)
            self.parameters += [kernel, biases]

          # pool5
          self.pool5 = tf.nn.max_pool(self.conv5_3,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME',
                               name='pool5')
          
     def fc_layers(self):
          # fc1
          with tf.name_scope('fc1') as scope:
               shape = int(np.prod(self.pool5.get_shape()[1:]))
               fc1w = tf.Variable(tf.truncated_normal([shape, 4096],
                                                      dtype=tf.float32,
                                                      stddev=1e-1), name='weights')
               fc1b = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32),
                                                      trainable=True, name='biases')
               pool5_flat = tf.reshape(self.pool5, [-1, shape])
               fc1l = tf.nn.bias_add(tf.matmul(pool5_flat, fc1w), fc1b)
               self.fc1 = tf.nn.relu(fc1l)
               self.parameters += [fc1w, fc1b]

          # fc2
          with tf.name_scope('fc2') as scope:
               fc2w = tf.Variable(tf.truncated_normal([4096, 4096],
                                                         dtype=tf.float32,
                                                         stddev=1e-1), name='weights')
               fc2b = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32),
                                 trainable=True, name='biases')
               fc2l = tf.nn.bias_add(tf.matmul(self.fc1, fc2w), fc2b)
               self.fc2 = tf.nn.relu(fc2l)
               self.parameters += [fc2w, fc2b]

          # fc3
          with tf.name_scope('fc3') as scope:
               fc3w = tf.Variable(tf.truncated_normal([4096, 1000],
                                                         dtype=tf.float32,
                                                         stddev=1e-1), name='weights')
               fc3b = tf.Variable(tf.constant(1.0, shape=[1000], dtype=tf.float32),
                                 trainable=True, name='biases')
               self.fc3l = tf.nn.bias_add(tf.matmul(self.fc2, fc3w), fc3b)
               self.parameters += [fc3w, fc3b]

     def load_weights(self, weight_file, sess):
          weights = np.load(weight_file)
          keys = sorted(weights.keys())
          for i, k in enumerate(keys):
               # print i,k,np.shape(weights[k])
               sess.run(self.parameters[i].assign(weights[k]))
               
